"""
Phase 1: Unified Panel Assembly for Nonlinear Framework
========================================================
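Merge the full multilateral panel with moderator variables from:
  - safe_assets: safe_issuer indicator
  - monetary: QE indicators (qe_active, qe_country, post_qe_tightening)
  - trilemma: EMU (eurozone) membership, plus is_oecd / oecd_floater when present

Income terciles, KAOPEN groups, OADR spline terms and the Z x moderator
interactions are derived from the merged panel itself. When a source panel
(or the expected column in it) is missing, the indicator is constructed from
hard-coded definitions instead.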

Creates a single panel with all DVs and moderators needed for
the varying-coefficient estimation.
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
import warnings
warnings.filterwarnings('ignore')

# Hard-coded absolute location of the overall project tree; every input and
# output path in this script is built from it.
ROOT = Path("/mnt/c/demographics_capital_flows")
# Sub-project that owns this script's outputs.
PROJECT = ROOT / "nonlinear_framework"
# Destination for the assembled panel CSV and for the markdown tables.
OUT_DATA = PROJECT / "data" / "processed"
OUT_TABLES = PROJECT / "output" / "tables"
# Output directories are created at import time (including missing parents).
OUT_DATA.mkdir(parents=True, exist_ok=True)
OUT_TABLES.mkdir(parents=True, exist_ok=True)

# Put the multilateral project's src directory first on the import path so
# that `model` resolves to that project's module.
# NOTE(review): PanelGLS is not referenced anywhere else in the code visible
# in this file -- confirm whether this import is still needed.
sys.path.insert(0, str(ROOT / "multilateral" / "src"))
from model import PanelGLS


def load_full_panel():
    """Read the multilateral full panel and keep observations up to 2024.

    Returns:
        A DataFrame copy of ``full_panel.csv`` restricted to rows whose
        ``year`` is at most 2024. A one-line size summary is printed.
    """
    processed_dir = ROOT / "multilateral" / "followup" / "data" / "processed"
    panel = pd.read_csv(processed_dir / "full_panel.csv")

    # Rows after 2024 are dropped; .copy() detaches the result from the
    # frame that was read so later column assignments act on an own copy.
    in_sample = panel['year'] <= 2024
    panel = panel[in_sample].copy()

    n_obs = len(panel)
    n_countries = panel['iso3'].nunique()
    print(f"Full panel: {n_obs} obs, {n_countries} countries")
    return panel


def add_safe_issuer(df):
    """Attach the ``safe_issuer`` indicator to the panel.

    The indicator is merged from the safe_assets panel on (iso3, year). If
    that file does not exist, or it has no ``safe_issuer`` column, the
    hard-coded fallback in ``_construct_safe_issuer`` is used instead.

    Args:
        df: Panel with ``iso3`` and ``year`` columns.

    Returns:
        The panel with an integer ``safe_issuer`` column; country-years with
        no match in the safe_assets panel are set to 0.
    """
    safe_path = ROOT / "safe_assets" / "data" / "processed" / "safe_asset_panel.csv"
    if not safe_path.exists():
        print("  Safe asset panel not found, constructing from ratings...")
        return _construct_safe_issuer(df)

    safe = pd.read_csv(safe_path)
    if 'safe_issuer' not in safe.columns:
        return _construct_safe_issuer(df)

    # Collapse to exactly one row per country-year before merging. A plain
    # drop_duplicates() keeps two rows when the same (iso3, year) carries
    # conflicting flags, and the left merge would then silently duplicate
    # panel observations. Taking the max treats a country-year as a safe
    # issuer if any of its source rows says so.
    safe_cols = safe.groupby(['iso3', 'year'], as_index=False)['safe_issuer'].max()
    df = df.merge(safe_cols, on=['iso3', 'year'], how='left',
                  validate='many_to_one')
    df['safe_issuer'] = df['safe_issuer'].fillna(0).astype(int)
    print(f"  Safe issuer: {df['safe_issuer'].sum()} safe-issuer obs")
    return df


def _construct_safe_issuer(df):
    """Fallback: flag safe issuers from a fixed, time-invariant country list.

    The list is the author's selection of sovereigns described as rated AA-
    or above for most of the sample. It is not limited to OECD members.

    Args:
        df: Panel with an ``iso3`` column. Modified in place.

    Returns:
        The same frame with an integer ``safe_issuer`` column (1 for listed
        countries in every year, 0 otherwise).
    """
    safe_codes = frozenset((
        'USA', 'DEU', 'GBR', 'FRA', 'JPN', 'CAN', 'AUS', 'CHE', 'NLD', 'AUT',
        'BEL', 'DNK', 'FIN', 'NOR', 'SWE', 'NZL', 'SGP', 'HKG', 'LUX',
        'KOR', 'TWN', 'ISR', 'QAT', 'KWT', 'ARE', 'CZE',
    ))
    listed = df['iso3'].isin(safe_codes)
    df['safe_issuer'] = listed.astype(int)

    n_flagged = df['safe_issuer'].sum()
    print(f"  Safe issuer (fallback): {n_flagged} obs")
    return df


def add_qe_indicator(df):
    """Attach QE indicators to the panel.

    Whichever of ``qe_active``, ``qe_country`` and ``post_qe_tightening``
    exist in the monetary panel are merged on (iso3, year). If the file is
    missing, or it has none of those columns, all three are built from the
    hard-coded episodes in ``_construct_qe``.

    Args:
        df: Panel with ``iso3`` and ``year`` columns.

    Returns:
        The panel with integer QE columns; unmatched country-years are 0.
    """
    mon_path = ROOT / "monetary" / "data" / "processed" / "monetary_panel.csv"
    if not mon_path.exists():
        print("  Monetary panel not found, constructing QE from definitions...")
        return _construct_qe(df)

    mon = pd.read_csv(mon_path)
    qe_cols = [col for col in ('qe_active', 'qe_country', 'post_qe_tightening')
               if col in mon.columns]
    if not qe_cols:
        return _construct_qe(df)

    # One row per country-year: drop_duplicates() alone leaves several rows
    # when a key carries conflicting flags, which would duplicate panel
    # observations in the left merge. max() keeps a flag that is set in any
    # of the source rows.
    mon_sub = mon.groupby(['iso3', 'year'], as_index=False)[qe_cols].max()
    df = df.merge(mon_sub, on=['iso3', 'year'], how='left',
                  validate='many_to_one')
    for col in qe_cols:
        df[col] = df[col].fillna(0).astype(int)

    # qe_active may be absent when the monetary panel only has the others.
    n_active = df['qe_active'].sum() if 'qe_active' in df.columns else 0
    print(f"  QE active: {n_active} obs")
    return df


def _construct_qe(df):
    """Build the three QE indicators from hard-coded episode windows.

    Columns written (the frame is modified in place and returned):
      * ``qe_active``: 1 inside a country's QE window (bounds inclusive).
      * ``qe_country``: 1 for every year of a country that is ever active
        within the rows of ``df``.
      * ``post_qe_tightening``: 1 from 2022 onwards for every country that
        has an episode defined here.

    NOTE(review): windows ending in 2022 make that year both QE-active and
    tightening, and JPN is active through 2024 while also flagged as
    tightening from 2022 -- confirm this overlap is intended.
    """
    national_windows = {
        'USA': (2008, 2014), 'GBR': (2009, 2022), 'JPN': (2001, 2024),
        'SWE': (2015, 2019), 'CHE': (2015, 2022),
    }
    emu_window = (2015, 2022)
    emu_members = {
        'AUT', 'BEL', 'CYP', 'DEU', 'ESP', 'EST', 'FIN', 'FRA', 'GRC',
        'IRL', 'ITA', 'LTU', 'LUX', 'LVA', 'MLT', 'NLD', 'PRT', 'SVK', 'SVN',
    }

    iso = df['iso3']
    year = df['year']

    # One window is shared by all listed euro-area members; the remaining
    # countries each have their own window.
    in_window = iso.isin(emu_members) & (year >= emu_window[0]) & (year <= emu_window[1])
    for code, (first, last) in national_windows.items():
        in_window = in_window | ((iso == code) & (year >= first) & (year <= last))

    df['qe_active'] = 0
    df.loc[in_window, 'qe_active'] = 1

    df['qe_country'] = df.groupby('iso3')['qe_active'].transform('max')

    qe_economies = set(national_windows) | emu_members
    df['post_qe_tightening'] = 0
    df.loc[iso.isin(qe_economies) & (year >= 2022), 'post_qe_tightening'] = 1

    n_active = df['qe_active'].sum()
    print(f"  QE active (constructed): {n_active} obs")
    return df


def add_emu_membership(df):
    """Attach the ``eurozone`` (EMU membership) indicator to the panel.

    If the trilemma panel exists and has a ``eurozone`` column, that column
    is merged on (iso3, year) together with ``is_oecd`` and ``oecd_floater``
    when they are present. Otherwise ``eurozone`` is constructed from a
    hard-coded table of euro adoption years.

    Args:
        df: Panel with ``iso3`` and ``year`` columns.

    Returns:
        The panel with an integer ``eurozone`` column (and possibly
        ``is_oecd`` / ``oecd_floater``); unmatched country-years are 0.
    """
    tri_path = ROOT / "trilemma" / "data" / "processed" / "trilemma_panel.csv"
    if tri_path.exists():
        tri = pd.read_csv(tri_path)
        if 'eurozone' in tri.columns:
            # Take the optional OECD columns along when the panel has them.
            flag_cols = ['eurozone'] + [c for c in ['is_oecd', 'oecd_floater']
                                        if c in tri.columns]
            # One row per country-year: conflicting duplicate keys would
            # otherwise duplicate panel observations in the left merge.
            emu_cols = tri.groupby(['iso3', 'year'], as_index=False)[flag_cols].max()
            df = df.merge(emu_cols, on=['iso3', 'year'], how='left',
                          validate='many_to_one')
            # NOTE(review): country-years absent from the trilemma panel get
            # is_oecd = 0 here, and add_income_groups only builds its own
            # is_oecd when the column does not exist yet -- confirm the
            # trilemma panel covers every OECD country in the sample.
            for col in ['eurozone', 'is_oecd', 'oecd_floater']:
                if col in df.columns:
                    df[col] = df[col].fillna(0).astype(int)
            print(f"  EMU: {df['eurozone'].sum()} obs")
            return df

    # Fallback: first year of euro membership per country. Croatia (2023) is
    # included because the panel is kept through 2024.
    EUROZONE_JOIN = {
        'AUT': 1999, 'BEL': 1999, 'FIN': 1999, 'FRA': 1999, 'DEU': 1999,
        'IRL': 1999, 'ITA': 1999, 'LUX': 1999, 'NLD': 1999, 'PRT': 1999,
        'ESP': 1999, 'GRC': 2001, 'SVN': 2007, 'CYP': 2008, 'MLT': 2008,
        'SVK': 2009, 'EST': 2011, 'LVA': 2014, 'LTU': 2015, 'HRV': 2023,
    }
    df['eurozone'] = 0
    for iso, yr in EUROZONE_JOIN.items():
        df.loc[(df['iso3'] == iso) & (df['year'] >= yr), 'eurozone'] = 1
    print(f"  EMU (constructed): {df['eurozone'].sum()} obs")
    return df


def add_income_groups(df):
    """Create country-level income terciles and an OECD indicator.

    Each country is classified once, by the median of its ``gdp_pc_ppp``
    over the rows in ``df``; the cut-offs are the 1/3 and 2/3 quantiles of
    those country medians. Countries with no GDP data at all receive a
    missing ``income_tercile`` (they were previously labelled 'middle' by
    default) and 0 in both dummies.

    Args:
        df: Panel with ``iso3`` and ``gdp_pc_ppp`` columns. Modified in place.

    Returns:
        The same frame with ``income_tercile`` ('low' / 'middle' / 'high' or
        missing), ``income_low``, ``income_high`` and, unless it already
        exists, a hard-coded ``is_oecd`` indicator.
    """
    # Income terciles from median GDP per capita
    median_gdppc = df.groupby('iso3')['gdp_pc_ppp'].median()
    tercile_cuts = median_gdppc.quantile([1/3, 2/3])
    low_cut, high_cut = tercile_cuts.iloc[0], tercile_cuts.iloc[1]

    country_tercile = pd.Series('middle', index=median_gdppc.index, dtype=object)
    country_tercile[median_gdppc <= low_cut] = 'low'
    country_tercile[median_gdppc > high_cut] = 'high'
    # A NaN median fails both comparisons above; mark it as missing rather
    # than letting it fall through to 'middle'.
    country_tercile[median_gdppc.isna()] = np.nan

    df['income_tercile'] = df['iso3'].map(country_tercile)
    df['income_low'] = (df['income_tercile'] == 'low').astype(int)
    df['income_high'] = (df['income_tercile'] == 'high').astype(int)

    # OECD indicator (if not already present)
    if 'is_oecd' not in df.columns:
        OECD = {
            'AUS', 'AUT', 'BEL', 'CAN', 'CHL', 'COL', 'CRI', 'CZE', 'DNK', 'EST',
            'FIN', 'FRA', 'DEU', 'GRC', 'HUN', 'ISL', 'IRL', 'ISR', 'ITA', 'JPN',
            'KOR', 'LVA', 'LTU', 'LUX', 'MEX', 'NLD', 'NZL', 'NOR', 'POL', 'PRT',
            'SVK', 'SVN', 'ESP', 'SWE', 'CHE', 'TUR', 'GBR', 'USA',
        }
        df['is_oecd'] = df['iso3'].isin(OECD).astype(int)

    countries_by_tercile = df.groupby('income_tercile')['iso3'].nunique()
    n_low = countries_by_tercile.get('low', 0)
    n_mid = countries_by_tercile.get('middle', 0)
    n_high = countries_by_tercile.get('high', 0)
    print(f"  Income terciles: low={n_low}, middle={n_mid}, high={n_high} countries")
    print(f"  OECD: {df['is_oecd'].sum()} obs ({df[df['is_oecd']==1]['iso3'].nunique()} countries)")
    return df


def add_kaopen_groups(df):
    """Create the KAOPEN saturation dummy and country-level openness groups.

    Nothing is added when ``df`` has no ``kaopen`` column.

    Args:
        df: Panel with ``iso3`` and, optionally, ``kaopen``. Modified in place.

    Returns:
        The same frame, with two extra columns when ``kaopen`` is present:
          * ``kaopen_saturated``: 1 where kaopen is at least 95% of the
            sample maximum (observation level).
          * ``kaopen_group``: 'closed' / 'mid_open' / 'open' by terciles of
            each country's median kaopen. Countries with no kaopen data get
            a missing value (they were previously labelled 'mid_open').
    """
    if 'kaopen' not in df.columns:
        return df

    # KAOPEN saturation: at or near ceiling.
    # NOTE(review): the 95%-of-max rule only means "near the ceiling" when
    # the sample maximum is positive -- confirm the scale of kaopen.
    kaopen_max = df['kaopen'].max()
    df['kaopen_saturated'] = (df['kaopen'] >= kaopen_max * 0.95).astype(int)

    # KAOPEN terciles of country medians
    kaopen_median = df.groupby('iso3')['kaopen'].median()
    kaopen_cuts = kaopen_median.quantile([1/3, 2/3])
    country_kaopen = pd.Series('mid_open', index=kaopen_median.index, dtype=object)
    country_kaopen[kaopen_median <= kaopen_cuts.iloc[0]] = 'closed'
    country_kaopen[kaopen_median > kaopen_cuts.iloc[1]] = 'open'
    # A NaN median fails both comparisons; keep it missing instead of
    # defaulting to the middle group.
    country_kaopen[kaopen_median.isna()] = np.nan
    df['kaopen_group'] = df['iso3'].map(country_kaopen)

    print(f"  KAOPEN saturated: {df['kaopen_saturated'].sum()} obs")
    return df


def add_oadr_spline(df):
    """Add hinge terms of the old-age dependency ratio for threshold analysis.

    For each knot k in 15, 20 and 25, ``oadr_above_k`` is
    ``max(old_dep * 100 - k, 0)``. Nothing is added when ``df`` has no
    ``old_dep`` column. The frame is modified in place and returned.
    """
    if 'old_dep' not in df.columns:
        return df

    # The knots are on a percentage scale, hence the factor of 100.
    # NOTE(review): this presumes old_dep is stored as a fraction -- confirm.
    oadr_pct = df['old_dep'] * 100
    for knot in (15, 20, 25):
        df[f'oadr_above_{knot}'] = (oadr_pct - knot).clip(lower=0)

    print("  OADR splines created at 15%, 20%, 25%")
    return df


def build_interactions(df):
    """Multiply each of Z_1..Z_3 with every moderator available in ``df``.

    For a source column with suffix ``s`` the products are stored as
    ``Z_1_x_s``, ``Z_2_x_s`` and ``Z_3_x_s``. Sources are, in order: the
    binary moderators, continuous ``kaopen`` (overwriting any existing
    ``Z_*_x_kaopen`` columns) and the ``oadr_above_*`` spline terms. Sources
    missing from ``df`` are skipped. The frame is modified in place and
    returned.
    """
    z_names = ('Z_1', 'Z_2', 'Z_3')

    # (source column, suffix used in the interaction name), in output order.
    plan = [
        ('safe_issuer', 'safe'),
        ('qe_active', 'qe'),
        ('eurozone', 'emu'),
        ('income_low', 'low'),
        ('income_high', 'high'),
        ('is_oecd', 'oecd'),
        ('kaopen_saturated', 'ksat'),
        ('kaopen', 'kaopen'),
    ]
    plan.extend((f'oadr_above_{knot}', f'oadr{knot}') for knot in (15, 20, 25))

    for source, suffix in plan:
        if source not in df.columns:
            continue
        for z in z_names:
            df[f'{z}_x_{suffix}'] = df[z] * df[source]

    # Counts every column containing '_x_', including any already in the panel.
    n_interactions = sum(1 for name in df.columns if '_x_' in name)
    print(f"  Interaction terms: {n_interactions} created")
    return df


def write_summary_stats(df):
    """Write the Phase 1 summary-statistics and moderator cross-tab tables.

    Two markdown files are written to ``OUT_TABLES``:
      * ``phase1_summary_stats.md``: N / mean / SD / min / max for the Z
        variables, dependent variables and moderators present in ``df``.
      * ``phase1_moderator_crosstab.md``: number of countries per moderator
        and an income-tercile by OECD table.

    Moderators missing from ``df`` are skipped rather than raising, and the
    files are written as UTF-8 because they contain non-ASCII characters.
    ``DataFrame.to_markdown`` needs the optional ``tabulate`` package.

    Args:
        df: Assembled panel with at least ``iso3`` and ``year``.
    """
    dvs = ['ca_gdp', 'govt_bond_10y', 'gross_investment_gdp', 'fiscal_bal_gdp']
    moderators = ['safe_issuer', 'qe_active', 'eurozone', 'is_oecd',
                  'income_low', 'income_high', 'kaopen', 'old_dep']
    z_vars = ['Z_1', 'Z_2', 'Z_3']

    stats = []
    for var in z_vars + dvs + moderators:
        if var not in df.columns:
            continue
        s = df[var].describe()
        stats.append({
            'Variable': var,
            'N': int(s['count']),
            'Mean': f"{s['mean']:.3f}",
            'SD': f"{s['std']:.3f}",
            'Min': f"{s['min']:.3f}",
            'Max': f"{s['max']:.3f}",
        })

    stats_df = pd.DataFrame(stats)

    # Write markdown
    with open(OUT_TABLES / "phase1_summary_stats.md", 'w', encoding='utf-8') as f:
        f.write("# Phase 1: Unified Panel Summary Statistics\n\n")
        f.write(stats_df.to_markdown(index=False))
        f.write(f"\n\n*Panel: {len(df)} obs, {df['iso3'].nunique()} countries, "
                f"{df['year'].min()}-{df['year'].max()}*\n")
    print("  Wrote: phase1_summary_stats.md")

    # One row per country: a country counts towards a moderator if the flag
    # is set in any of its years. Only aggregate columns that exist so a
    # missing moderator does not abort the whole table.
    agg_spec = {
        'income_tercile': 'first',
        'safe_issuer': 'max',
        'is_oecd': 'max',
        'qe_active': 'max',
        'eurozone': 'max',
    }
    agg_spec = {col: how for col, how in agg_spec.items() if col in df.columns}
    cross = df.groupby('iso3').agg(agg_spec).reset_index()
    n_total = len(cross)

    # Moderator cross-tabulation
    with open(OUT_TABLES / "phase1_moderator_crosstab.md", 'w', encoding='utf-8') as f:
        f.write("# Phase 1: Moderator Cross-Tabulation\n\n")

        f.write("## Country Counts by Moderator\n\n")
        f.write("| Moderator | Count | % of Countries |\n")
        f.write("|---|---|---|\n")
        for mod, label in [('is_oecd', 'OECD'), ('safe_issuer', 'Safe Issuer'),
                           ('qe_active', 'QE Country'), ('eurozone', 'EMU')]:
            if mod not in cross.columns:
                continue
            n = cross[mod].sum()
            f.write(f"| {label} | {n} | {100*n/n_total:.1f}% |\n")

        if 'income_tercile' in cross.columns and 'is_oecd' in cross.columns:
            f.write("\n## Income Tercile × OECD\n\n")
            ct = pd.crosstab(cross['income_tercile'], cross['is_oecd'],
                             margins=True)
            # Rename by label instead of assigning a fixed three-name list:
            # the crosstab has fewer columns when the sample contains only
            # OECD or only non-OECD countries.
            ct = ct.rename(columns={0: 'Non-OECD', 1: 'OECD', 'All': 'Total'})
            f.write(ct.to_markdown())
            f.write("\n")

    print("  Wrote: phase1_moderator_crosstab.md")


def main():
    """Run Phase 1 end to end.

    Loads the full panel, adds every moderator, builds the interaction
    terms, writes the summary tables and saves the result as
    ``unified_panel.csv`` in ``OUT_DATA``.
    """
    rule = "=" * 70
    print("Phase 1: Unified Panel Assembly")
    print(rule)

    panel = load_full_panel()

    print("\nAdding moderator variables...")
    # Order matters: add_emu_membership may bring in is_oecd, which
    # add_income_groups only constructs when it is still absent.
    moderator_steps = (
        add_safe_issuer,
        add_qe_indicator,
        add_emu_membership,
        add_income_groups,
        add_kaopen_groups,
        add_oadr_spline,
    )
    for step in moderator_steps:
        panel = step(panel)

    print("\nBuilding interaction terms...")
    panel = build_interactions(panel)

    print("\nWriting summary statistics...")
    write_summary_stats(panel)

    # Save
    destination = OUT_DATA / "unified_panel.csv"
    panel.to_csv(destination, index=False)
    n_obs = len(panel)
    n_countries = panel['iso3'].nunique()
    n_columns = len(panel.columns)
    print(f"\nSaved: {destination}")
    print(f"  {n_obs} obs, {n_countries} countries, {n_columns} columns")

    print("\n" + rule)
    print("Phase 1 complete.")

if __name__ == '__main__':
    main()
